import sys
import numpy as np

from tqdm import tqdm, trange

# Tag inventory: 'O' plus B-/I- prefixed tags for the LOC, PER and ORG types.
# A tag's position in this list is its integer id.
categories = [
    'O',
    'B-LOC', 'I-LOC',
    'B-PER', 'I-PER',
    'B-ORG', 'I-ORG',
]
# id -> tag string
categories_id2label = dict(enumerate(categories))
# tag string -> id
categories_label2id = {label: idx for idx, label in enumerate(categories)}


def trans_entity2tuple(scores):
    '''Convert per-token tag ids into entity tuples used for computing metrics.

    Args:
        scores: iterable of samples; each sample is an iterable of elements
            exposing ``.item()`` whose result is a key of
            ``categories_id2label`` (e.g. rows of a 2-D integer array/tensor).

    Returns:
        A set of ``(sample index, start, end, entity type)`` tuples. ``start``
        is the position of the ``B-`` tag and ``end`` is the position of the
        last consecutive ``I-`` tag of the same type (inclusive); ``end``
        equals ``start`` when no such ``I-`` tag follows.
    '''
    found = set()
    for sample_idx, tag_ids in enumerate(scores):
        # Entity currently being extended; None while no entity is open.
        open_span = None
        for pos, tag_id in enumerate(tag_ids):
            tag = categories_id2label[tag_id.item()]
            kind = tag[2:]
            if tag.startswith('B-'):
                # A new B- tag always starts a fresh entity; keep the old one.
                if open_span is not None:
                    found.add(tuple(open_span))
                open_span = [sample_idx, pos, pos, kind]
            elif open_span is None:
                # Stray I-/O tokens with nothing to attach to are ignored.
                continue
            elif tag.startswith('I-') and kind == open_span[-1]:
                # Same-type continuation: push the inclusive end forward.
                open_span[-2] = pos
            else:
                # 'O' or an I- tag of another type closes the open entity.
                found.add(tuple(open_span))
                open_span = None

        if open_span is not None:
            found.add(tuple(open_span))
    return found


def get_bertcls_post(outputs, batch_size, params, content):
    """Compute per-sample token-level and entity-level match counts.

    Args:
        outputs: sequence of arrays. Each ``outputs[i]`` is copied and reshaped
            to ``(-1, d0, d1)`` where ``d0, d1`` come from the i-th entry of
            ``params['outputs_size']``. Only ``outputs[0]`` is scored: it is
            arg-maxed over its last axis to obtain predicted tag ids
            (presumably per-token class scores -- confirm with the caller).
        batch_size: number of leading samples to evaluate.
        params: mapping with key ``'outputs_size'``, a string in which ``#``
            separates the outputs and ``,`` separates the integer dimensions
            of one output.
        content: indexable; ``content[idx][1]`` is an iterable of values
            convertible with ``int()`` that gives the reference tag ids of
            sample ``idx``.

    Returns:
        A list with one ``[X, Y, Z, X2, Y2, Z2]`` entry per sample:
        X  - positions where prediction equals label and label > 0,
        Y  - positions with prediction > 0,
        Z  - positions with label > 0,
        X2 - entity tuples present in both prediction and label,
        Y2 - predicted entity tuples,
        Z2 - labelled entity tuples.
    """
    shapes = [[int(dim) for dim in spec.split(",")]
              for spec in params['outputs_size'].split("#")]
    outputs = [out.copy().reshape(-1, shapes[i][0], shapes[i][1])
               for i, out in enumerate(outputs)]

    npreds = []
    for idx in range(batch_size):
        pred = np.argmax(outputs[0][idx:idx+1], axis=-1)
        label = np.array([int(x) for x in content[idx][1]]).astype(np.int64).reshape(1, -1)

        # Align the prediction with the label length. The model's sequence
        # length is taken from ``pred`` itself rather than assumed to be 128,
        # so other output sizes neither fail to broadcast nor mismatch the
        # label shape.
        pred_len = pred.shape[1]
        if label.shape[1] > pred_len:
            # Label is longer than the model output: positions the model
            # never saw stay 0 (the id of 'O').
            scores = np.zeros([1, label.shape[1]], dtype=np.int64)
            scores[:, :pred_len] = pred
        else:
            scores = pred[:, :label.shape[1]]
        attention_mask = label > 0

        # token granularity
        X = ((scores == label) * attention_mask).sum().item()
        Y = (scores > 0).sum().item()
        Z = (label > 0).sum().item()

        # entity granularity
        entity_pred = trans_entity2tuple(scores)
        entity_true = trans_entity2tuple(label)
        X2 = len(entity_pred.intersection(entity_true))
        Y2 = len(entity_pred)
        Z2 = len(entity_true)

        npreds.append([X, Y, Z, X2, Y2, Z2])

    return npreds